import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torch.nn.functional as F
#from utils import train
import torchvision.models as models
import time
import matplotlib.pyplot as plt
import cvxpy as cvx
import scipy.io as scio
# Wall-clock start; the elapsed time is printed at the very end of the script.
time_start = time.time()
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Pre-extracted feature matrices and label vectors for three datasets
# (file suffixes a / d / w — presumably three domains of one benchmark; confirm).
x1 = np.load('./data/o31_feature/feature_a_4096.npy')
y1 = np.load('./data/o31_feature/label_a.npy')
x2 = np.load('./data/o31_feature/feature_d_4096.npy')
y2 = np.load('./data/o31_feature/label_d.npy')
x3 = np.load('./data/o31_feature/feature_w_4096.npy')
y3 = np.load('./data/o31_feature/label_w.npy')


def _draw_per_class(features, labels, per_class, num_class=31):
    """Stack `per_class` randomly chosen rows of `features` for every label.

    Classes are visited in increasing order 0..num_class-1 and rows are drawn
    without replacement via np.random.choice, so the number and order of
    random draws is one call per class.
    """
    picked = []
    for cls in range(num_class):
        pool = features[labels == cls]
        chosen = np.random.choice(pool.shape[0], per_class, replace=False)
        picked.append(pool[chosen])
    return np.vstack(picked)


# Reference set: 3 rows per class taken from the "d" data.
refx = _draw_per_class(x2, y2, 3)
refy = np.repeat(np.arange(31), 3).astype(int)

# Source set: 20 rows per class taken from the "a" data.
sourcex = _draw_per_class(x1, y1, 20)
sourcey = np.repeat(np.arange(31), 20).astype(int)

# --- Feature network f: maps 4096-d inputs to a 64-d embedding ---
class Net_f(nn.Module):
    """Two-layer fully connected feature network.

    The forward pass is ``fc2(relu(fc1(x)))``; no activation is applied to
    the output. Layer sizes are constructor arguments whose defaults
    reproduce the original fixed 4096 -> 1024 -> 64 architecture, so
    ``Net_f()`` and its ``state_dict`` keys (``fc1.*``, ``fc2.*``) are
    unchanged.

    Args:
        in_dim: Size of the last dimension of the input.
        hidden_dim: Width of the hidden layer.
        out_dim: Size of the returned embedding.
    """

    def __init__(self, in_dim=4096, hidden_dim=1024, out_dim=64):
        super(Net_f, self).__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Return the ``out_dim``-wide embedding of ``x``."""
        out = F.relu(self.fc1(x))
        out = self.fc2(out)
        return out


class Net_g(nn.Module):
    """Single linear layer taking ``num_class``-wide inputs to ``dim`` outputs."""

    def __init__(self, num_class=31, dim=64):
        super().__init__()
        # The attribute name `fc` determines the state_dict keys.
        self.fc = nn.Linear(num_class, dim)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.fc(x)

def corr(f,g):
    """Return the mean over rows of the row-wise inner product of ``f`` and ``g``.

    Both arguments are 2-D tensors of the same shape; the result is a
    0-dimensional tensor. No centring is done here.
    """
    rowwise_dot = (f * g).sum(dim=1)
    return rowwise_dot.mean()
    
def cov_trace(f,g):
    """Return ``trace(C_f @ C_g)`` with ``C_x = x^T x / (rows(x) - 1)``.

    The inputs are used as given: no mean is subtracted inside this function,
    so callers that want covariances must pass centred tensors.
    """
    def _scaled_gram(m):
        # m^T m divided by (number of rows - 1).
        return m.t().mm(m) / (m.size(0) - 1.)

    return torch.trace(_scaled_gram(f).mm(_scaled_gram(g)))

def neg_hscore(f,g):
    """Return ``-E[f0 . g0] + trace(cov(f0) @ cov(g0)) / 2`` for a batch.

    ``f0`` and ``g0`` are ``f`` and ``g`` with their column means removed.
    The expectation is the mean over rows of the row-wise inner product, and
    each covariance is ``x0^T x0 / (rows - 1)``. The result is a 0-dimensional
    tensor.
    """
    f_centred = f - f.mean(dim=0)
    g_centred = g - g.mean(dim=0)

    # Mean row-wise inner product of the centred features.
    inner = (f_centred * g_centred).sum(dim=1).mean()

    cov_f = f_centred.t().mm(f_centred) / (f_centred.size(0) - 1.)
    cov_g = g_centred.t().mm(g_centred) / (g_centred.size(0) - 1.)
    half_trace = torch.trace(cov_f.mm(cov_g)) / 2.

    return - inner + half_trace

# ---- Training configuration -------------------------------------------------
lr=0.0002
epoch=100
ind=0  # number of optimisation steps taken so far
model_f = Net_f().to(device)
model_g = Net_g().to(device)
# A single Adam optimiser updates both networks jointly.
optimizer_fg = torch.optim.Adam(list(model_f.parameters())+list(model_g.parameters()),lr=lr)
losslist=[]
acclist=[0]  # seeded with 0 so max(acclist) is defined on the first epoch
alpha=[0.9,0.1]  # weights of the reference / source correlation terms in the loss

samples_ref=torch.from_numpy(refx)
labels_ref=torch.from_numpy(refy)
# scatter_ needs int64 indices; .long() keeps this working where the default
# NumPy int is 32-bit.
labels_one_hot_ref = torch.zeros(len(labels_ref), 31).scatter_(1, labels_ref.view(-1,1).long(), 1)
samples_trans=torch.from_numpy(sourcex)
labels_trans=torch.from_numpy(sourcey)
labels_one_hot_trans= torch.zeros(len(labels_trans), 31).scatter_(1, labels_trans.view(-1,1).long(), 1)

# The inputs never change, so convert and move them to the device once
# instead of on every epoch. The CPU tensors above keep their names because
# code after the loop still reads them.
ref_in = samples_ref.float().to(device)
ref_onehot_in = labels_one_hot_ref.float().to(device)
trans_in = samples_trans.float().to(device)
trans_onehot_in = labels_one_hot_trans.float().to(device)
test_in = torch.from_numpy(x2).float().to(device)
labellist = torch.Tensor(np.eye(31))  # one one-hot row per class
class_in = labellist.to(device)


def _snapshot(model):
    """Return a detached copy of the model's state_dict.

    state_dict() hands back references to the live parameter tensors, so
    without cloning every later optimiser step would silently overwrite the
    saved "best" weights.
    """
    return {name: tensor.clone() for name, tensor in model.state_dict().items()}


# Best checkpoint seen so far (None until some epoch exceeds the 0.5 threshold).
paraf = None
parag = None
finalacc = None

for i in range(epoch):
    model_f.train()
    model_g.train()

    f_ref = model_f(ref_in)
    g_ref = model_g(ref_onehot_in)
    f_ref_mean = torch.mean(f_ref,0)
    g_ref_mean = torch.mean(g_ref,0)
    f0_ref = f_ref - f_ref_mean
    g0_ref = g_ref - g_ref_mean
    # Source-batch outputs are centred with the *reference* batch means.
    f_trans = model_f(trans_in) - f_ref_mean
    g_trans = model_g(trans_onehot_in) - g_ref_mean
    optimizer_fg.zero_grad()

    # Weighted correlation terms on both sets plus the covariance-trace
    # term computed on the reference set only.
    loss=(-2)*alpha[0]*corr(f0_ref,g0_ref)
    loss+=(-2)*alpha[1]*corr(f_trans,g_trans)
    loss+=cov_trace(f0_ref,g0_ref)
    losslist.append(loss.item())
    loss.backward()
    optimizer_fg.step()
    ind+=1

    # ------ accuracy on the "d" data, excluding the sampled reference rows
    model_f.eval()
    model_g.eval()
    with torch.no_grad():  # evaluation only: no autograd graph needed
        f_src = model_f(trans_in).cpu().numpy()
        g_cls = model_g(class_in).cpu().numpy()
        f_test = model_f(test_in).cpu().numpy()
        f_refset = model_f(ref_in).cpu().numpy()

    # Centre features with the source-set mean and class embeddings with
    # their own mean, then predict the class with the largest inner product.
    f_mean = np.sum(f_src,axis=0)/f_src.shape[0]
    gcp = g_cls - np.sum(g_cls,axis=0)/g_cls.shape[0]

    acc = (np.argmax(np.dot(f_test-f_mean,gcp.T), axis = 1) == y2).sum()
    total = len(x2)

    acc1 = (np.argmax(np.dot(f_refset-f_mean,gcp.T), axis = 1) == refy).sum()
    total1 = len(refx)

    # Remove the reference rows' contribution from both counts.
    acc=(acc-acc1)/(total-total1)

    if acc > 0.5 and acc > max(acclist):
        paraf = _snapshot(model_f)
        parag = _snapshot(model_g)
        finalacc = acc
    acclist.append(acc)

if paraf is None:
    # No epoch exceeded the 0.5 threshold. Fall back to the last-epoch weights
    # and accuracy so the code below still has a checkpoint to load instead of
    # failing with a NameError.
    paraf = _snapshot(model_f)
    parag = _snapshot(model_g)
    finalacc = acclist[-1]
print(finalacc)

# ---- Post-training statistics and the resulting alpha -----------------------
# Rebuild the feature network from the saved checkpoint. The features below
# are taken from this restored network; the original code loaded it but then
# evaluated `model_f` (the last-epoch weights) instead.
model_fa = Net_f().to(device)
model_fa.load_state_dict(paraf)
model_fa.eval()

# Only values are needed from here on, so skip autograd and bring the features
# back to the CPU. Every tensor below then lives on one device (the label
# masks and the accumulators are CPU tensors).
with torch.no_grad():
    fstar_ref = model_fa(samples_ref.float().to(device)).cpu()
    fstar_trans = model_fa(samples_trans.float().to(device)).cpu()

num_class = 31
feat_dim = fstar_ref.size(1)
n_ref = labels_ref.shape[0]
n_trans = labels_trans.shape[0]

# Both feature sets are centred with the reference-set mean.
ref_mean = torch.mean(fstar_ref,0)
ftilde_ref = fstar_ref - ref_mean
ftilde_trans = fstar_trans - ref_mean

# Empirical class frequencies of the reference and source label vectors.
py_ref = np.array([np.sum(labels_ref.numpy()==c) for c in range(num_class)])/n_ref
py_trans = np.array([np.sum(labels_trans.numpy()==c) for c in range(num_class)])/n_trans

# Covariance of the centred reference features; invert it once and reuse it.
lambdaf = torch.mm(torch.t(ftilde_ref),ftilde_ref) / (n_ref-1.)
lambdaf_inv = torch.inverse(lambdaf)


def _quad_form(vec):
    """Return vec^T @ lambdaf_inv @ vec as a Python float."""
    row = vec.reshape(1, feat_dim)
    return torch.mm(row, lambdaf_inv).mm(torch.t(row)).item()


def _second_moment_trace(rows):
    """Return trace(lambdaf_inv @ (rows^T rows / number of rows)) as a float."""
    return torch.trace(torch.mm(lambdaf_inv, torch.mm(torch.t(rows), rows)/rows.size()[0])).item()


v_ref=0
v_trans=0
h=torch.zeros((feat_dim,feat_dim))
for i in range(num_class):
    p_ref = float(py_ref[i])
    p_trans = float(py_trans[i])
    a_ref = ftilde_ref[labels_ref==i]
    a_trans = ftilde_trans[labels_trans==i]
    mean_ref = torch.mean(a_ref,0)
    mean_trans = torch.mean(a_trans,0)

    # Per-class contributions to the reference-set term.
    v_ref += 1/n_ref*_second_moment_trace(a_ref)
    v_ref -= 1/n_ref*p_ref*_quad_form(mean_ref)

    # Per-class contributions to the source-set term, reweighted by the
    # ratio of source to reference class frequencies.
    v_trans += 1/n_trans*p_trans/p_ref*_second_moment_trace(a_trans)
    v_trans -= 1/n_trans*p_trans*p_trans/p_ref*_quad_form(mean_trans)

    # Outer product of the difference between the frequency-weighted class
    # means of the two sets, scaled by 1 / reference class frequency.
    diff = (p_ref*mean_ref - p_trans*mean_trans).reshape(feat_dim,1)
    h += 1/p_ref*torch.mm(diff, torch.t(diff))

d=torch.trace(torch.mm(lambdaf_inv,h))

alpha=v_ref/(v_ref+v_trans+d.item())
print(alpha)
time_end=time.time()
print(time_end-time_start)
